import json
import os
from collections import defaultdict
import numpy as np
from utils import geoquery_dataset

def count_left_parentheses(text):
    """Return the number of '(' characters contained in ``text``."""
    return sum(1 for char in text if char == '(')

def calculate_accuracy_mini(predictions, gold_data, split_type):
    """Compute exact-match accuracy overall and per parenthesis-count group.

    The target for each example is ``' ' + gold['anonymized']``; a prediction
    is correct when ``pred['answer']`` equals that string exactly. Examples
    are bucketed by the number of '(' characters in the target, using
    thresholds that depend on ``split_type``. Examples that fall into no
    bucket (group 0) are excluded from the per-group statistics but still
    count towards the overall accuracy.

    Args:
        predictions: Sequence of dicts, each with an ``'answer'`` key.
        gold_data: Sequence of dicts, each with an ``'anonymized'`` key,
            aligned position-by-position with ``predictions``.
        split_type: ``'standard'``, ``'tmcd'``, ``'template'`` or
            ``'length'``. Any other value leaves every example in group 0.

    Returns:
        A tuple ``(accuracy, paren_accuracies)``. ``accuracy`` is the number
        of correct predictions divided by ``len(predictions)`` (``0`` when
        there are no predictions). ``paren_accuracies`` maps each group id
        that received at least one example to the accuracy inside that group.
    """
    # Per split: (group id, min '(' count, max '(' count or None = no upper bound).
    group_bounds = {
        'standard': ((1, 3, 5), (2, 6, None)),
        'tmcd': ((1, 2, 4), (2, 5, None)),
        'template': ((1, 3, 4), (2, 5, None)),
        'length': ((1, 3, 4), (2, 5, 5), (3, 6, None)),
    }
    bounds = group_bounds.get(split_type, ())

    correct = 0
    total = len(predictions)
    paren_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, gold in zip(predictions, gold_data):
        answer = ' ' + gold['anonymized']
        num_parens = answer.count('(')

        # Overall correctness is tracked for every example so the numerator
        # matches the ``len(predictions)`` denominator, grouped or not.
        is_correct = pred['answer'] == answer
        if is_correct:
            correct += 1

        group = 0
        for group_id, low, high in bounds:
            if num_parens >= low and (high is None or num_parens <= high):
                group = group_id
                break

        if group > 0:
            paren_stats[group]['total'] += 1
            if is_correct:
                paren_stats[group]['correct'] += 1

    accuracy = correct / total if total > 0 else 0

    paren_accuracies = {
        group: (stats['correct'] / stats['total'] if stats['total'] > 0 else 0)
        for group, stats in paren_stats.items()
    }

    return accuracy, paren_accuracies

def load_jsonl(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    instead of failing the whole load. The file is read as UTF-8.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        stripped_lines = (raw_line.strip() for raw_line in f)
        return [json.loads(line) for line in stripped_lines if line]

def calculate_accuracy(predictions, gold_data):
    """Return the exact-match accuracy of ``predictions`` against ``gold_data``.

    A prediction is a hit when its ``'answer'`` value equals the gold
    ``'anonymized'`` value prefixed with a single space. The hit count is
    divided by ``len(predictions)``; with no predictions the result is ``0``.
    """
    n_predictions = len(predictions)
    # Build each target before looking at the prediction, pair by pair.
    targets = (' ' + gold['anonymized'] for gold in gold_data)
    n_hits = sum(
        1
        for pred, target in zip(predictions, targets)
        if pred['answer'] == target
    )

    if n_predictions > 0:
        return n_hits / n_predictions
    return 0

def print_paren_old(splits, k_values, paren_counts, paren_results):
    """Print the accuracy gap between 'knn_diversity-knn' and 'knn-knn'.

    For every split this prints the sample count per key of
    ``paren_counts[split]``, then, for every ``k`` and every key of
    ``paren_results[split]['knn-knn'][k]``, the difference of the mean
    accuracies (diversity minus plain knn) formatted to four decimals,
    followed by the first recorded accuracy of each method. Keys for which
    either method has no recorded accuracy are skipped.

    Args:
        splits: Iterable of split names.
        k_values: Iterable of ``k`` values to report.
        paren_counts: Mapping ``split -> {key: sample count}``.
        paren_results: Mapping ``split -> method -> k -> key -> list of
            accuracy values``.
    """
    print("\n=== Accuracy Statistics for Different Parenthesis Counts ===")
    for split_type in splits:
        print(f"\n{split_type} split:")
        print("\nParenthesis Count Distribution:")
        counts = paren_counts[split_type]
        for num_parens in sorted(counts):
            print(f"Parenthesis Count {num_parens}: {counts[num_parens]} samples")

        for k in k_values:
            print(f"\nk={k}:")
            for num_parens in sorted(paren_results[split_type]['knn-knn'][k]):
                knn_accs = paren_results[split_type]['knn-knn'][k][num_parens]
                if not knn_accs:
                    continue
                # Looked up only after the knn check so a defaultdict input
                # does not gain new keys for buckets that are skipped anyway.
                diversity_accs = paren_results[split_type]['knn_diversity-knn'][k][num_parens]
                if not diversity_accs:
                    continue

                mean_acc = -np.mean(knn_accs) + np.mean(diversity_accs)
                print(f"Parenthesis Count {num_parens}  knn_diversity - knn: {mean_acc:.4f}")
                # Only the first recorded run is shown next to the mean gap.
                print('knn-knn: ', knn_accs[0])
                print('knn_diversity-knn: ', diversity_accs[0])

def print_paren(splits, k_values, paren_counts, paren_results):
    """Print the accuracy gap between 'knn_diversity-knn' and 'knn-knn' in percent.

    For every split this prints the sample count per key of
    ``paren_counts[split]``, then, for every ``k`` and every key of
    ``paren_results[split]['knn-knn'][k]``, the difference of the mean
    accuracies (diversity minus plain knn) and the first recorded accuracy
    of each method, all multiplied by 100 and formatted to two decimals.
    Keys for which either method has no recorded accuracy are skipped.

    Args:
        splits: Iterable of split names.
        k_values: Iterable of ``k`` values to report.
        paren_counts: Mapping ``split -> {key: sample count}``.
        paren_results: Mapping ``split -> method -> k -> key -> list of
            accuracy values``.
    """
    print("\n=== Accuracy Statistics for Different Parenthesis Counts ===")
    for split_type in splits:
        print(f"\n{split_type} split:")
        print("\nParenthesis Count Distribution:")
        counts = paren_counts[split_type]
        for num_parens in sorted(counts):
            print(f"Parenthesis Count {num_parens}: {counts[num_parens]} samples")

        for k in k_values:
            print(f"\nk={k}:")
            for num_parens in sorted(paren_results[split_type]['knn-knn'][k]):
                knn_accs = paren_results[split_type]['knn-knn'][k][num_parens]
                if not knn_accs:
                    continue
                # Looked up only after the knn check so a defaultdict input
                # does not gain new keys for buckets that are skipped anyway.
                diversity_accs = paren_results[split_type]['knn_diversity-knn'][k][num_parens]
                if not diversity_accs:
                    continue

                mean_acc = -np.mean(knn_accs) + np.mean(diversity_accs)
                print(f"Parenthesis Count {num_parens}  knn_diversity - knn: {100*mean_acc:.2f}")
                # Only the first recorded run is shown next to the mean gap.
                print(f'knn-knn: {100*knn_accs[0]:.2f}')
                print(f'knn_diversity-knn: {100*diversity_accs[0]:.2f}')

def analyze_results():
    """Score saved prediction files and print mean/std accuracy per method and k.

    For each configured model, embedding and split, every ``.jsonl`` file in
    ``./results/test/geo880-<split>/<model>/<emb>`` whose name starts with a
    known method name is loaded. ``k`` is parsed from the text between the
    last ``-`` in the file name and the following ``.``. Files with an
    unparseable or out-of-range ``k``, or whose length differs from the gold
    data, are skipped. Each remaining file is scored with
    ``calculate_accuracy_mini`` and the accuracies are averaged per
    ``(split, method, k)``.
    """
    models = ['gemma-2-9b']
    embs = ['all-roberta-large-v1']

    # split type -> concrete split names whose results are pooled under it
    splits = {
        'standard': ['standard'],
    }
    k_values = range(1, 31, 1)

    for model in models:
        for emb in embs:
            print(f"model_name: {model}")
            print(f"emb_name: {emb}")

            # results[split][method][k] -> list of overall accuracies (one per file)
            results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
            file_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
            # NOTE: the per-group statistics below are collected but not
            # printed anywhere in this function.
            paren_results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
            paren_counts = defaultdict(lambda: defaultdict(int))

            methods = ['random-knn', 'knn-knn', 'diversity-knn', 'knn_diversity-knn']
            for split_type, split_names in splits.items():
                for split_name in split_names:
                    _, gold_data = geoquery_dataset.get_preprocessed_geoquery(split_name)

                    # Distribution of '(' counts over the gold targets.
                    for gold in gold_data:
                        answer = ' ' + gold['anonymized']
                        num_parens = count_left_parentheses(answer)
                        paren_counts[split_type][num_parens] += 1

                    result_dir = f'./results/test/geo880-{split_name}/{model}/{emb}'
                    if not os.path.exists(result_dir):
                        print(f"Warning: Directory not found: {result_dir}")
                        continue

                    for file_name in os.listdir(result_dir):
                        if not file_name.endswith('.jsonl'):
                            continue

                        # First method whose name prefixes the file name.
                        method = next((m for m in methods if file_name.startswith(m)), None)
                        if method is None:
                            continue

                        # Only a non-numeric suffix is an expected failure here;
                        # anything else (e.g. Ctrl-C) must propagate.
                        try:
                            k = int(file_name.split('-')[-1].split('.')[0])
                        except ValueError:
                            continue
                        if k not in k_values:
                            continue

                        predictions = load_jsonl(os.path.join(result_dir, file_name))

                        if len(predictions) != len(gold_data):
                            print(f"Warning: Size mismatch in {file_name}: "
                                f"predictions={len(predictions)}, gold={len(gold_data)}")
                            continue

                        accuracy, paren_accuracies = calculate_accuracy_mini(predictions, gold_data, split_type)
                        results[split_type][method][k].append(accuracy)
                        file_counts[split_type][method][k] += 1
                        for num_parens, acc in paren_accuracies.items():
                            paren_results[split_type][method][k][num_parens].append(acc)

            print("\nAverage accuracy for each method and k:")
            for split_type in splits:
                print(f"\n{split_type} split:")
                for method in methods:
                    print(f"\n{method}:")
                    for k in k_values:
                        if results[split_type][method][k]:
                            mean_acc = np.mean(results[split_type][method][k])
                            std_acc = np.std(results[split_type][method][k])
                            # LaTeX-style "$mean_{std}$" in percent, plus the number of files averaged.
                            print(f"k={k}: ${100*mean_acc:.2f}_{{{100*std_acc:.2f}}}$, len: {len(results[split_type][method][k])}")
                        else:
                            print(f"k={k}: No results")

            print("\n=== Combined Results Across All Splits ===")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    analyze_results()